import logging
import queue
import threading
from queue import Queue
from typing import Callable, Dict, Mapping, Sequence

from common.agent_configuration.agent_sub_configurations import ExploitationConfiguration
from common.types import Event
from infection_monkey.i_puppet import ExploiterResult, IPuppet, RejectedRequestError, TargetHost
from infection_monkey.utils.threading import interruptible_iter, run_worker_threads

# Seconds a worker blocks on the hosts queue before re-checking whether it should exit.
QUEUE_TIMEOUT = 2

# Module-scoped logger rather than the root logger, so records carry this module's name
# and can be filtered or levelled per module. Records still propagate to root handlers.
logger = logging.getLogger(__name__)

# Exploiters are referred to by name; these names are the keys of the exploiter
# configuration mapping.
ExploiterName = str

# Signature of the per-attempt results callback:
# (exploiter name, target host, result of the attempt) -> None
Callback = Callable[
    [ExploiterName, TargetHost, ExploiterResult],
    None,
]


class Exploiter:
    """Run the configured exploiters against hosts pulled from a shared queue.

    Hosts are consumed by ``num_workers`` worker threads. For each host, exploiters are
    tried in configuration order until one reports successful propagation, every
    exploiter has been tried, or iteration is interrupted via the ``stop`` event.
    """

    def __init__(
        self,
        puppet: IPuppet,
        num_workers: int,
    ):
        """
        :param puppet: The puppet used to run each exploiter against a host
        :param num_workers: Number of worker threads that consume the hosts queue
        """
        self._puppet = puppet
        self._num_workers = num_workers

    def exploit_hosts(
        self,
        exploitation_config: ExploitationConfiguration,
        hosts_to_exploit: Queue,
        current_depth: int,
        servers: Sequence[str],
        results_callback: Callback,
        scan_completed: threading.Event,
        stop: Event,
    ):
        """Exploit hosts from ``hosts_to_exploit`` using a pool of worker threads.

        :param exploitation_config: The exploiters to run (in order) and their options
        :param hosts_to_exploit: Queue of ``TargetHost`` objects. Workers keep polling it
                                 until ``scan_completed`` is set and the queue is empty.
        :param current_depth: Passed through unchanged to each exploiter
        :param servers: Passed through unchanged to each exploiter
        :param results_callback: Called with (exploiter name, host, result) after every
                                 exploitation attempt that was not rejected
        :param scan_completed: Expected to be set once no more hosts will be queued
        :param stop: When set, workers stop taking new hosts and stop trying exploiters
        """
        exploiter_configs = self._process_exploiter_config(exploitation_config)
        logger.debug(
            "Agent is configured to run the following exploiters in order: "
            f"{', '.join(exploiter_configs)}"
        )

        exploit_args = (
            exploiter_configs,
            hosts_to_exploit,
            current_depth,
            servers,
            results_callback,
            scan_completed,
            stop,
        )
        run_worker_threads(
            target=self._exploit_hosts_on_queue,
            name_prefix="ExploiterThread",
            args=exploit_args,
            num_workers=self._num_workers,
        )

    @staticmethod
    def _process_exploiter_config(
        exploitation_config: ExploitationConfiguration,
    ) -> Dict[ExploiterName, Mapping]:
        """Merge the general exploitation options into each exploiter's own options.

        The returned dict preserves the order of ``exploitation_config.exploiters``, which
        is the order the exploiters will be run in.
        """
        general_options = exploitation_config.options.__dict__
        # Later entries win when unpacking, so exploiter-specific options override the
        # general options shared by all exploiters.
        return {
            exploiter_name: {**general_options, **exploiter_options}
            for exploiter_name, exploiter_options in exploitation_config.exploiters.items()
        }

    def _exploit_hosts_on_queue(
        self,
        exploiter_configs: Dict[ExploiterName, Mapping],
        hosts_to_exploit: Queue,
        current_depth: int,
        servers: Sequence[str],
        results_callback: Callback,
        scan_completed: threading.Event,
        stop: Event,
    ):
        """Worker loop: take hosts off the queue and exploit them until told to finish.

        The loop exits when ``stop`` is set, or when a queue read times out after the
        scan has completed and the queue is empty.
        """
        logger.debug(f"Starting exploiter thread -- Thread ID: {threading.get_ident()}")

        while not stop.is_set():
            # Only the queue read is guarded. A queue.Empty raised while exploiting
            # (e.g. by the results callback) must not be mistaken for a read timeout.
            try:
                target_host = hosts_to_exploit.get(timeout=QUEUE_TIMEOUT)
            except queue.Empty:
                # Nothing arrived in time. Exit only if no more hosts can show up;
                # otherwise loop and keep waiting (re-checking ``stop``).
                if _all_hosts_have_been_processed(scan_completed, hosts_to_exploit):
                    break
            else:
                self._run_all_exploiters(
                    exploiter_configs, target_host, current_depth, servers, results_callback, stop
                )

        logger.debug(
            f"Exiting exploiter thread -- Thread ID: {threading.get_ident()} -- "
            f"stop.is_set(): {stop.is_set()} -- network_scan_completed: "
            f"{scan_completed.is_set()}"
        )

    def _run_all_exploiters(
        self,
        exploiter_configs: Dict[ExploiterName, Mapping],
        target_host: TargetHost,
        current_depth: int,
        servers: Sequence[str],
        results_callback: Callback,
        stop: Event,
    ):
        """Try each configured exploiter on ``target_host`` in order.

        Rejected requests are skipped without invoking ``results_callback``. Every other
        attempt is reported through ``results_callback``. The remaining exploiters are
        skipped once one reports successful propagation.
        """
        for exploiter_name, exploiter_config in interruptible_iter(exploiter_configs.items(), stop):
            try:
                exploiter_results = self._run_exploiter(
                    exploiter_name,
                    exploiter_config,
                    target_host,
                    current_depth,
                    servers,
                    stop,
                )
            except RejectedRequestError:
                continue

            results_callback(exploiter_name, target_host, exploiter_results)

            if exploiter_results.propagation_success:
                break

    def _run_exploiter(
        self,
        exploiter_name: str,
        options: Mapping,
        target_host: TargetHost,
        current_depth: int,
        servers: Sequence[str],
        stop: Event,
    ) -> ExploiterResult:
        """Run a single exploiter against ``target_host`` via the puppet.

        :return: The puppet's result. If the puppet raises an unexpected exception, a
                 failed ``ExploiterResult`` carrying the error message is returned instead.
        :raises RejectedRequestError: If the puppet rejects the request (logged, re-raised)
        """
        logger.debug(f"Attempting to use {exploiter_name} on {target_host.ip}")

        try:
            return self._puppet.exploit_host(
                exploiter_name, target_host, current_depth, servers, options, stop
            )
        except RejectedRequestError as err:
            logger.info(f"The request to exploit {target_host.ip} was rejected: {err}")
            # Bare raise re-raises the active exception with its original traceback.
            raise
        except Exception as err:
            msg = (
                f"An unexpected error occurred while exploiting {target_host.ip} with "
                f"{exploiter_name}: {err}"
            )
            # logger.exception logs at ERROR level and appends the traceback, so the
            # message and the stack trace end up in a single record.
            logger.exception(msg)
            return ExploiterResult(
                exploitation_success=False, propagation_success=False, error_message=msg
            )


def _all_hosts_have_been_processed(scan_completed: threading.Event, hosts_to_exploit: Queue):
    return scan_completed.is_set() and hosts_to_exploit.empty()
